In [1]:
# Lib used
import io
import time
import copy
import urllib
import requests

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
matplotlib.rcParams.update({'font.size': 22})

!pip install -U scikit-image # need upgraded version
import skimage
import skimage.io
import skimage.filters
import skimage.transform
import skimage.morphology
import skimage.segmentation

from PIL import Image
import cv2

import numpy as np

import scipy
import scipy.ndimage
from scipy.stats.kde import gaussian_kde

from IPython import display

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
from torch.optim import Adam

from torchvision import transforms
from torch.autograd import Variable

!pip install skan
import skan
import skan.draw

def axisoff(ax):
    """Turn off the axes frame/ticks on every Axes in the (possibly 2-D) array *ax*.

    Returns the list of values returned by each ``axis('off')`` call,
    matching the behavior of the lambda it replaces (PEP 8 discourages
    binding a lambda to a name; a ``def`` is clearer and easier to debug).
    """
    return [x.axis('off') for x in ax.ravel()]
Collecting scikit-image
  Downloading scikit_image-0.19.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.0 MB)
     |████████████████████████████████| 14.0 MB 4.7 MB/s eta 0:00:01
Requirement already satisfied, skipping upgrade: scipy>=1.4.1 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (1.7.2)
Requirement already satisfied, skipping upgrade: PyWavelets>=1.1.1 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (1.2.0)
Requirement already satisfied, skipping upgrade: imageio>=2.4.1 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (2.10.4)
Requirement already satisfied, skipping upgrade: networkx>=2.2 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (2.6.3)
Requirement already satisfied, skipping upgrade: packaging>=20.0 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (21.2)
Requirement already satisfied, skipping upgrade: numpy>=1.17.0 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (1.21.4)
Requirement already satisfied, skipping upgrade: pillow!=7.1.0,!=7.1.1,!=8.3.0,>=6.1.0 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (8.4.0)
Requirement already satisfied, skipping upgrade: tifffile>=2019.7.26 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image) (2021.11.2)
Requirement already satisfied, skipping upgrade: pyparsing<3,>=2.0.2 in /home/catherine/.local/lib/python3.8/site-packages (from packaging>=20.0->scikit-image) (2.4.7)
Installing collected packages: scikit-image
  Attempting uninstall: scikit-image
    Found existing installation: scikit-image 0.18.3
    Uninstalling scikit-image-0.18.3:
      Successfully uninstalled scikit-image-0.18.3
Successfully installed scikit-image-0.19.2
Collecting skan
  Downloading skan-0.10.0-py3-none-any.whl (1.5 MB)
     |████████████████████████████████| 1.5 MB 10.5 MB/s eta 0:00:01
Requirement already satisfied: matplotlib>=3.0 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (3.4.3)
Requirement already satisfied: scikit-image>=0.17 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (0.19.2)
Collecting toolz>=0.10.0
  Downloading toolz-0.11.2-py3-none-any.whl (55 kB)
     |████████████████████████████████| 55 kB 8.5 MB/s  eta 0:00:01
Collecting numpydoc>=0.9.2
  Downloading numpydoc-1.2-py3-none-any.whl (51 kB)
     |████████████████████████████████| 51 kB 17.2 MB/s eta 0:00:01
Requirement already satisfied: pandas>=1.0 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (1.3.4)
Collecting openpyxl>=2.4
  Downloading openpyxl-3.0.9-py2.py3-none-any.whl (242 kB)
     |████████████████████████████████| 242 kB 67.5 MB/s eta 0:00:01
Requirement already satisfied: networkx>=2.0 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (2.6.3)
Requirement already satisfied: imageio>=2.0 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (2.10.4)
Requirement already satisfied: scipy>=1.2.0 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (1.7.2)
Collecting tqdm>=4.56.0
  Downloading tqdm-4.63.0-py2.py3-none-any.whl (76 kB)
     |████████████████████████████████| 76 kB 10.7 MB/s  eta 0:00:01
Collecting numba>=0.50
  Downloading numba-0.55.1-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (3.4 MB)
     |████████████████████████████████| 3.4 MB 79.0 MB/s eta 0:00:01
Requirement already satisfied: numpy>=1.16.5 in /home/catherine/.local/lib/python3.8/site-packages (from skan) (1.21.4)
Requirement already satisfied: pillow>=6.2.0 in /home/catherine/.local/lib/python3.8/site-packages (from matplotlib>=3.0->skan) (8.4.0)
Requirement already satisfied: cycler>=0.10 in /home/catherine/.local/lib/python3.8/site-packages (from matplotlib>=3.0->skan) (0.11.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /home/catherine/.local/lib/python3.8/site-packages (from matplotlib>=3.0->skan) (1.3.2)
Requirement already satisfied: python-dateutil>=2.7 in /usr/lib/python3/dist-packages (from matplotlib>=3.0->skan) (2.7.3)
Requirement already satisfied: pyparsing>=2.2.1 in /home/catherine/.local/lib/python3.8/site-packages (from matplotlib>=3.0->skan) (2.4.7)
Requirement already satisfied: PyWavelets>=1.1.1 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image>=0.17->skan) (1.2.0)
Requirement already satisfied: tifffile>=2019.7.26 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image>=0.17->skan) (2021.11.2)
Requirement already satisfied: packaging>=20.0 in /home/catherine/.local/lib/python3.8/site-packages (from scikit-image>=0.17->skan) (21.2)
Requirement already satisfied: Jinja2>=2.10 in /home/catherine/.local/lib/python3.8/site-packages (from numpydoc>=0.9.2->skan) (3.0.3)
Collecting sphinx>=1.8
  Downloading Sphinx-4.4.0-py3-none-any.whl (3.1 MB)
     |████████████████████████████████| 3.1 MB 106.5 MB/s eta 0:00:01
Requirement already satisfied: pytz>=2017.3 in /usr/lib/python3/dist-packages (from pandas>=1.0->skan) (2019.3)
Collecting et-xmlfile
  Downloading et_xmlfile-1.1.0-py3-none-any.whl (4.7 kB)
Collecting llvmlite<0.39,>=0.38.0rc1
  Downloading llvmlite-0.38.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.5 MB)
     |████████████████████████████████| 34.5 MB 104.9 MB/s eta 0:00:01
Requirement already satisfied: setuptools in /usr/lib/python3/dist-packages (from numba>=0.50->skan) (45.2.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/catherine/.local/lib/python3.8/site-packages (from Jinja2>=2.10->numpydoc>=0.9.2->skan) (2.0.1)
Collecting alabaster<0.8,>=0.7
  Downloading alabaster-0.7.12-py2.py3-none-any.whl (14 kB)
Collecting sphinxcontrib-applehelp
  Downloading sphinxcontrib_applehelp-1.0.2-py2.py3-none-any.whl (121 kB)
     |████████████████████████████████| 121 kB 103.4 MB/s eta 0:00:01
Collecting sphinxcontrib-serializinghtml>=1.1.5
  Downloading sphinxcontrib_serializinghtml-1.1.5-py2.py3-none-any.whl (94 kB)
     |████████████████████████████████| 94 kB 7.4 MB/s  eta 0:00:01
Requirement already satisfied: babel>=1.3 in /home/catherine/.local/lib/python3.8/site-packages (from sphinx>=1.8->numpydoc>=0.9.2->skan) (2.9.1)
Collecting sphinxcontrib-jsmath
  Downloading sphinxcontrib_jsmath-1.0.1-py2.py3-none-any.whl (5.1 kB)
Collecting sphinxcontrib-devhelp
  Downloading sphinxcontrib_devhelp-1.0.2-py2.py3-none-any.whl (84 kB)
     |████████████████████████████████| 84 kB 7.5 MB/s  eta 0:00:01
Collecting snowballstemmer>=1.1
  Downloading snowballstemmer-2.2.0-py2.py3-none-any.whl (93 kB)
     |████████████████████████████████| 93 kB 3.2 MB/s s eta 0:00:01
Collecting imagesize
  Downloading imagesize-1.3.0-py2.py3-none-any.whl (5.2 kB)
Requirement already satisfied: importlib-metadata>=4.4; python_version < "3.10" in /home/catherine/.local/lib/python3.8/site-packages (from sphinx>=1.8->numpydoc>=0.9.2->skan) (4.8.2)
Collecting sphinxcontrib-htmlhelp>=2.0.0
  Downloading sphinxcontrib_htmlhelp-2.0.0-py2.py3-none-any.whl (100 kB)
     |████████████████████████████████| 100 kB 21.0 MB/s eta 0:00:01
Requirement already satisfied: requests>=2.5.0 in /usr/lib/python3/dist-packages (from sphinx>=1.8->numpydoc>=0.9.2->skan) (2.22.0)
Collecting docutils<0.18,>=0.14
  Downloading docutils-0.17.1-py2.py3-none-any.whl (575 kB)
     |████████████████████████████████| 575 kB 58.4 MB/s eta 0:00:01
Collecting sphinxcontrib-qthelp
  Downloading sphinxcontrib_qthelp-1.0.3-py2.py3-none-any.whl (90 kB)
     |████████████████████████████████| 90 kB 26.2 MB/s  eta 0:00:01
Requirement already satisfied: Pygments>=2.0 in /home/catherine/.local/lib/python3.8/site-packages (from sphinx>=1.8->numpydoc>=0.9.2->skan) (2.10.0)
Requirement already satisfied: zipp>=0.5 in /home/catherine/.local/lib/python3.8/site-packages (from importlib-metadata>=4.4; python_version < "3.10"->sphinx>=1.8->numpydoc>=0.9.2->skan) (3.6.0)
Installing collected packages: toolz, alabaster, sphinxcontrib-applehelp, sphinxcontrib-serializinghtml, sphinxcontrib-jsmath, sphinxcontrib-devhelp, snowballstemmer, imagesize, sphinxcontrib-htmlhelp, docutils, sphinxcontrib-qthelp, sphinx, numpydoc, et-xmlfile, openpyxl, tqdm, llvmlite, numba, skan
  Attempting uninstall: tqdm
    Found existing installation: tqdm 4.29.1
    Uninstalling tqdm-4.29.1:
      Successfully uninstalled tqdm-4.29.1
Successfully installed alabaster-0.7.12 docutils-0.17.1 et-xmlfile-1.1.0 imagesize-1.3.0 llvmlite-0.38.0 numba-0.55.1 numpydoc-1.2 openpyxl-3.0.9 skan-0.10.0 snowballstemmer-2.2.0 sphinx-4.4.0 sphinxcontrib-applehelp-1.0.2 sphinxcontrib-devhelp-1.0.2 sphinxcontrib-htmlhelp-2.0.0 sphinxcontrib-jsmath-1.0.1 sphinxcontrib-qthelp-1.0.3 sphinxcontrib-serializinghtml-1.1.5 toolz-0.11.2 tqdm-4.63.0

What is an image?¶

From the point of view of a computer, an image is a 2-dimensional matrix (grayscale images) or a 3-dimensional array (color images).

In [ ]:
matplotlib.rcParams.update({'font.size': 22})

###---Parameters---###
crop_size = 15  # side length (in pixels) of the zoomed-in corner view
###----------------###

# Load the example image; build a grayscale version by averaging the three
# color channels and casting back to 8-bit values (0-255).
image_color = skimage.data.astronaut()
image_gray = np.mean(image_color, axis=2).astype(np.uint8) # RGB to Grayscale
# NOTE: removed an unused `output_shape` variable that was computed here.
image_crop = image_gray[:crop_size,:crop_size]  # top-left corner of the image

fig, axs = plt.subplots(2, 2, figsize=(20,20))

# Right-hand column: zoomed corner views with the numeric pixel values overlaid.
axs[0,1].imshow(image_crop, cmap='gist_gray')
axs[1,1].imshow(image_color[:crop_size,:crop_size])
axs[0,1].set_title('Vue agrandie du coin de l\'image')
axs[1,1].set_title('Vue agrandie du coin de l\'image RGB')
for x in range(image_crop.shape[1]):
    for y in range(image_crop.shape[0]):
        rgb_pixel_intensity = image_color[y, x]
        grayscale_pixel_intensity = image_crop[y, x]
        # Pick a text color that contrasts with the underlying gray level.
        adaptative_gray_color = 'black' if grayscale_pixel_intensity > 255/2 else 'white'

        axs[0,1].text(x, y, s=grayscale_pixel_intensity, 
            horizontalalignment='center', 
            verticalalignment='center',
            fontsize=14,
            color=adaptative_gray_color) # Gray channel
        # The three RGB values are stacked vertically inside each pixel cell.
        axs[1,1].text(x, y-0.3, s=rgb_pixel_intensity[0], 
            horizontalalignment='center', 
            verticalalignment='center',
            fontsize=10,
            color='red') # Red channel
        axs[1,1].text(x, y, s=rgb_pixel_intensity[1], 
            horizontalalignment='center', 
            verticalalignment='center',
            fontsize=10,
            color='green') # Green channel
        axs[1,1].text(x, y+0.3, s=rgb_pixel_intensity[2],
            horizontalalignment='center', 
            verticalalignment='center',
            fontsize=10,
            color='blue') # Blue channel

# Left-hand column: the full images, with the cropped region outlined in red.
axs[0,0].imshow(image_gray, cmap='gist_gray')
axs[0,0].set_title(f"Image Shape: {image_gray.shape}")

# Add rectangle patches of the cropped regions
rectG = matplotlib.patches.Rectangle(
    (0, 0), width=crop_size, height=crop_size,
    linewidth=5, edgecolor='red', facecolor='none')
rectRGB = matplotlib.patches.Rectangle(
    (0, 0), width=crop_size, height=crop_size, 
    linewidth=5, edgecolor='red', facecolor='none')
axs[0,0].add_patch(rectG)
axs[1,0].add_patch(rectRGB)

axs[1,0].imshow(image_color)
axs[1,0].set_title(f"Image Shape: {image_color.shape}")
axisoff(axs)
print()
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-1-550c3b5a4803> in <module>()
----> 1 matplotlib.rcParams.update({'font.size': 22})
      2 
      3 ###---Parameters---###
      4 crop_size = 15
      5 ###----------------###

NameError: name 'matplotlib' is not defined

Image Processing¶

There are countless possible image processing operations. For example, above we converted a RGB image with 3 channels to a one channel grayscale image.

We will start by introducing morphological operations, which are operations where each pixel in the image is adjusted based on the value of other pixels in its neighborhood.

Data¶

To test these operations, we will use a nice neuron image, credit: https://www.fiercebiotech.com/research/ucsd-targets-dendritic-spines-promising-alzheimer-s-approach

In [ ]:
# Fetch the example neuron photograph over HTTP, decode it directly from the
# in-memory response bytes, and display it full-frame without axis ticks.
response = requests.get('https://qtxasset.com/2016-04/neuron.jpg?PCK0DzNbfCngFqlsA9peo5TjhOAXJE0x')
image_buffer = io.BytesIO(response.content)
img = skimage.io.imread(image_buffer)
fig, ax = plt.subplots(figsize=(12, 12))
ax.imshow(img)
ax.axis('off')
print()

Preprocessing Steps¶

Basic morphological operations are applied to a binary image, so we are talking about pixels with values of either 0 or 1.

To get there, we will apply a series of operations. Specifically, we are going to do the following:

  1. Convert from a RGB image to a grayscale image
  2. Apply a Gaussian filter to attenuate the image content to achieve a better quality mask (with $\sigma$ = 3)
  3. Find a good threshold to separate the foreground from the background. We will use the threshold_triangle algorithm (see doc).
  4. Apply the threshold to get our image with pixel values of 0 or 1. Here are the parameters we are going to use.

You can modify the parameters defined at the top of the cells and see by yourself the impact on the masks.

In [ ]:
# Parameters
GAUSSIAN_BLUR_SIGMA = 3  # standard deviation (in pixels) of the Gaussian blur

# 1. Convert to a grayscale image. We obtain values between 0 and 1
img_gray = skimage.color.rgb2gray(img)

# 2. Apply a gaussian filter with a sigma of GAUSSIAN_BLUR_SIGMA
img_blur = skimage.filters.gaussian(img_gray, sigma=GAUSSIAN_BLUR_SIGMA, mode='mirror')

# 3. Compute the threshold using triangle algorithm, otsu algorithm is doing well also.
threshold = skimage.filters.threshold_triangle(img_blur)
img_threshold = img_blur > threshold

# Show each step of the pipeline side by side.
fig, axes = plt.subplots(2, 2, sharex=True, sharey=True, tight_layout=True, figsize=(12, 12))
axes[0, 0].imshow(img)
axes[0, 0].set_title('Original RGB image', fontsize=22)
axes[0, 1].imshow(img_gray, cmap='gray')
axes[0, 1].set_title('Grayscale image', fontsize=22)
# FIX: `img_b` and `img_t` were never defined (NameError); the variables
# computed above are `img_blur` and `img_threshold`.
axes[1, 0].imshow(img_blur, cmap='gray')
# `\\sigma` keeps the rendered string identical while avoiding the
# invalid-escape DeprecationWarning of a bare `\s` in a non-raw string.
axes[1, 0].set_title(f"Blurred Image\n$\\sigma$={GAUSSIAN_BLUR_SIGMA}", fontsize=22)
axes[1, 1].imshow(img_threshold, cmap='gray')
axes[1, 1].set_title(f"Binary Mask\nThreshold value: {threshold:0.3f}", fontsize=22)
axisoff(axes)
print()

Morphological Operations¶

Now that we have our binary image, we can test different morphological operations. There are also several types of morphological operations. Here we will focus on the 4 most popular.

  • Dilation: *inflate* the mask
  • Erosion: *erode* the mask
  • Opening: erosion followed by a dilation. Mostly used to remove small region.
  • Closing: dilation followed by an erosion. Mostly used to fill gaps in a mask.

We will be using a disk structuring element of $radius=5$. The structuring element defines the neighborhood of each pixel. More info here.

You can modify the parameters defined at the top of the cells and see by yourself the impact on the masks.

In [ ]:
SELEM_RADIUS = 5  # radius (in pixels) of the disk structuring element

# Crop a region with the same aspect ratio as the original image
ratio = img.shape[0] / img.shape[1]
row, col = (600, 100) # location of the crop
crop_height = 400 # height of crop
crop_width = int(crop_height / ratio)
crop_slices = (
    slice(row, row + crop_height), 
    slice(col, col + crop_width)
    )

# The slices object is used to access the original image after padding
slices = (
    slice(SELEM_RADIUS, SELEM_RADIUS + img_threshold.shape[0]), 
    slice(SELEM_RADIUS, (SELEM_RADIUS + img_threshold.shape[1]))
    )

# Padding to keep the same image size after morphological operations
pad_values = (SELEM_RADIUS, SELEM_RADIUS)
img_threshold_pad = np.pad(img_threshold, (pad_values, pad_values), 'constant', constant_values=False)

# We create our structuring element as a disk of radius SELEM_RADIUS.
# FIX: `skimage.morphology.selem.disk` no longer exists in scikit-image 0.19
# (the version this notebook installs above); `skimage.morphology.disk` is
# the supported public name in every recent release.
selem = skimage.morphology.disk(SELEM_RADIUS)

# Plot the structuring element
fig, ax = plt.subplots(figsize=(2, 2))
ax.imshow(selem, cmap='viridis')
ax.set_title('Structuring Element')
ax.axis('off')

# Define our 4 morphological operations
morph_ops = {
    'Dilation': skimage.morphology.binary_dilation,
    'Erosion': skimage.morphology.binary_erosion,
    'Opening': skimage.morphology.binary_opening,
    'Closing': skimage.morphology.binary_closing
}

# Show the operation on the whole image view as well as on the cropped region
fig, axes = plt.subplots(6, 2, tight_layout=True, figsize=[14, 30])
matplotlib.rcParams['font.size'] = 16
axes = axes.flatten()
axes[0].imshow(img)
axes[0].set_title('Original\nimage')
axes[6].imshow(img[crop_slices])
axes[6].set_title('Crop Original\nimage')  # FIX: title was missing here
axes[1].imshow(img_threshold, cmap='gray')
axes[1].set_title('Thresholded\nimage')
axes[7].imshow(img_threshold[crop_slices], cmap='gray')
axes[7].set_title('Crop Thresholded\nimage')
for i, (ax, morphological_operation) in enumerate(zip(axes[2:6], morph_ops.keys())):
    # Operate on the padded mask so the image borders behave correctly,
    # then strip the padding back off with `slices` for display.
    im_morph = morph_ops[morphological_operation](img_threshold_pad, selem)
    ax.imshow(im_morph[slices], cmap='gray')
    ax.set_title(f'{morphological_operation} operation\nradius={SELEM_RADIUS}')
    axes[8 + i].imshow(im_morph[slices][crop_slices], cmap='gray')
    axes[8 + i].set_title(f'{morphological_operation} operation\nradius={SELEM_RADIUS}')
axisoff(axes)
print()

Skeletonization Operation¶

In this particular case, an interesting morphological operation would be skeletonization. The goal of this operation is to infer the skeleton underlying the mask. It is then possible to automatically quantify the branching of a neuron.

This can also be useful in other contexts. For instance, let's say you have implemented a segmentation tool for Google Maps. You could use skeletonization to get a graph representation of roads!

But let's stick with neurons :)

For the foreground mask, we will be using the one generated after applying a dilation operation.

In [ ]:
# Simple function to add the rectangle patch around the cropped region
def add_crop_region(axes, slices):
    for ax in axes:
        height = slices[0].stop - slices[0].start
        width = slices[1].stop - slices[1].start
        row = slices[0].start
        col = slices[1].start
        rect = matplotlib.patches.Rectangle(
            (col, row), width=width, height=height, 
            linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)

# We will use the dilation operation before doing the skeletonization.
# FIX: `skimage.morphology.selem.disk` was removed in scikit-image 0.19
# (installed above); use `skimage.morphology.disk` instead.
selem = skimage.morphology.disk(SELEM_RADIUS)
img_dilated = skimage.morphology.binary_dilation(img_threshold_pad, selem)

# Apply the skeletonization on the padded mask, then strip the padding
img_skeletonize = skimage.morphology.skeletonize(img_dilated)[slices]
img_dilated = img_dilated[slices] # remove the padding

# Crop a region (same aspect ratio as the full image)
ratio = img.shape[0] / img.shape[1]
row, col = (500, 400)
crop_height = 400
crop_width = int(crop_height / ratio)
crop_slices = (
    slice(row, row + crop_height), 
    slice(col, col + crop_width)
    )

fig, axes = plt.subplots(3, 3, tight_layout=True, figsize=(10, 10))
matplotlib.rcParams['font.size'] = 16
axes[0, 0].imshow(img)
axes[0, 0].set_title('Original\nRGB image')
axes[0, 1].imshow(img_blur, cmap='gray')
axes[0, 1].set_title(f'Blured image\nsigma={GAUSSIAN_BLUR_SIGMA}')
axes[1, 0].imshow(img_threshold, cmap='gray')
axes[1, 0].set_title('Thresholded\nimage')
# NOTE(review): `img_dilated` was already un-padded above, so indexing it
# with `slices` again crops it further — confirm this is intended.
axes[1, 1].imshow(img_dilated[slices], cmap='gray')
axes[1, 1].set_title(f'Dilated image\nradius={SELEM_RADIUS}')
axes[2, 0].imshow(img_skeletonize > 0, cmap='viridis')
axes[2, 0].set_title('Skeletonized\nimage')
# FIX: `img_g` was never defined (NameError); the grayscale image computed
# earlier is `img_gray`.
skan.draw.overlay_skeleton_2d(img_gray, img_skeletonize, axes=axes[2, 1])
axes[2, 1].set_title('Skeletonized\nimage')
axes[0, 2].imshow(img[crop_slices])
axes[1, 2].imshow(img_dilated[crop_slices], cmap='gray')
axes[2, 2].imshow(img_skeletonize[crop_slices] > 0, cmap='gray')
add_crop_region([axes[0, 0], axes[1, 1], axes[2, 0]], crop_slices)
axisoff(axes)
print()

Manipulation of the skeletonize image¶

The morphological operations are now done; we can now manipulate the skeletonized image and remove certain types of segments. We will be using the skan package (see doc https://jni.github.io/skan/).

This library will convert the skeletonized image into a graph. Each intersection of lines is a node in the graph. Each endpoint of a line is also a node in the graph. So there can be only 3 types of segments:

  • endpoint to endpoint
  • junction to endpoint
  • junction to junction

Here we will remove segments matching these rules:

  • The segment is bounded by 2 endpoints or by junction and endpoint
  • AND the segment is smaller than MINIMAL_LENGTH
  • OR the segment length is equal to the RESOLUTION
In [ ]:
# Parameters
RESOLUTION = 0.120 # Pixels dimension in um
MINIMAL_LENGTH = 10.0 # Minimal length of segments in um

# We define those variables to facilitate the understanding of the code.
# They mirror the `branch-type` codes produced by skan.summarize.
ENDPOINT_TO_ENDPOINT = 0
JUNCTION_TO_ENDPOINT = 1
JUNCTION_TO_JUNCTION = 2
ISOLATE_CYCLE = 3

# Convert a skeleton image of thin lines to a graph of neighbor pixels.
# (Shown for illustration; the pruning below only needs `skeleton`.)
pixel_graph, coordinates, degrees = skan.skeleton_to_csgraph(img_skeletonize)

# Create the skeleton object; `spacing` converts pixel distances to um.
skeleton = skan.Skeleton(img_skeletonize, spacing=RESOLUTION)

# Get branch data: one row per segment with its type and length.
branch_data = skan.summarize(skeleton)

# This loops over segments and checks if they respect our rule
updated_img_skeletonize = img_skeletonize.copy()
for i, path in branch_data.iterrows():
    if (path['branch-type'] in [ENDPOINT_TO_ENDPOINT, JUNCTION_TO_ENDPOINT] \
    and path['branch-distance'] < MINIMAL_LENGTH) \
    or path['branch-distance'] == RESOLUTION:
        # This if statement verifies 3 things:
        # 1. The segment is bounded by 2 endpoints or by junction and endpoint
        # 2. AND the segment is smaller than MINIMAL_LENGTH
        # 3. OR the segment length is equal to the RESOLUTION
        # If this statement is respected, we remove that segment.
        # FIX: `np.int` was deprecated in NumPy 1.20 and removed in 1.24;
        # the builtin `int` is the supported replacement.
        coords = skeleton.path_coordinates(i).astype(int)
        updated_img_skeletonize[coords[:, 0], coords[:, 1]] = False

# Plotting 
fig, axes = plt.subplots(3, 3, tight_layout=True, figsize=(10, 10)) # sharex=True, sharey=True,
matplotlib.rcParams['font.size'] = 16
axes[0, 0].imshow(img)
axes[0, 0].set_title('Original\nRGB image')
axes[0, 1].imshow(img_blur, cmap='gray')
axes[0, 1].set_title(f'Blurred image\nsigma={GAUSSIAN_BLUR_SIGMA}')
axes[1, 0].imshow(img_threshold, cmap='gray')
axes[1, 0].set_title('Thresholded\nimage')
axes[1, 1].imshow(img_dilated[slices], cmap='gray')
axes[1, 1].set_title(f'Dilated image\nradius={SELEM_RADIUS}')
skan.draw.overlay_skeleton_2d(img_gray, img_skeletonize, axes=axes[2, 0])
axes[2, 0].set_title('Original\nskeletonized')
skan.draw.overlay_skeleton_2d(img_gray, updated_img_skeletonize, axes=axes[2, 1])
axes[2, 1].set_title('Modified\nskeletonized')
axes[0, 2].imshow(img[crop_slices])
axes[1, 2].imshow(img_skeletonize[crop_slices] > 0, cmap='gray')
axes[2, 2].imshow(updated_img_skeletonize[crop_slices] > 0, cmap='gray')
add_crop_region([axes[0, 0], axes[2, 0], axes[2, 1]], crop_slices)
axisoff(axes)
print()

Regions Analysis¶

Let's jump to another use-case. Here we want to count and characterize the instances in an image. In our case, we are interested in the cells present in this image.

In [ ]:
# Load the bundled human-mitosis microscopy frame and show it with a
# colorbar so the staining-intensity scale is visible.
img = skimage.data.human_mitosis()

fig, ax = plt.subplots(figsize=(10, 10))
im_ax = ax.imshow(img, cmap='inferno')
ax.set_axis_off()
ax.set_title('Microscopy image of human cells\nstained for nuclear DNA')
fig.colorbar(im_ax)
print()
Downloading file 'data/mitosis.tif' from 'https://gitlab.com/scikit-image/data/-/raw/master/AS_09125_050116030001_D03f00d0.tif' to '/root/.cache/scikit-image/0.18.1'.

Processing Steps¶

We want to count the number of cells and characterize their shape. This is not a straightforward threshold application, since cells can overlap. Here is our plan:

  1. Compute a threshold using the triangle algorithm to find the best spot to separate the foreground from the background
  2. Apply the threshold to get a binary image
  3. Detect individual cells, which can be defined as isolated regions. To do so, we will be using skimage.measure.label. More info here. Each cell is defined by a unique label in the image.
  4. We will compute the Euclidean Distance Transform on the detected cells (more infos) which compute the distance to the nearest border for each pixel in a cell
  5. From the euclidean distance transform, we can find local maxima which can be approximated as center of cells. We set that no center can be closer than MIN_DISTANCE=3 pixels.
  6. Using those maxima, we will apply a watershed algorithm which will try to separate the cell masks. More info on the watershed algorithm.
In [ ]:
# Parameters
MIN_DISTANCE = 3  # minimal separation (in pixels) allowed between two cell centers

# 1. Compute the threshold 
threshold = skimage.filters.threshold_triangle(img)

# 2. Threshold the image with the computed threshold
cells = img > threshold

# 3. Detect individual cells: each connected foreground region receives
#    a unique integer label
labeled_cells = skimage.measure.label(cells)
labeled_cells_rgb = skimage.color.label2rgb(labeled_cells, bg_label=0)

# 4. Compute the Euclidean distance transform: for every foreground pixel,
#    its distance to the nearest background pixel
distance = scipy.ndimage.distance_transform_edt(cells)

# 5. Find local maxima of the distance map; they approximate cell centers
local_max_coords = skimage.feature.peak_local_max(distance, min_distance=MIN_DISTANCE)
local_max_mask = np.zeros_like(distance, dtype=bool)
local_max_mask[tuple(local_max_coords.T)] = True

# 6. Using the local maxima as markers, we will apply the watershed algo
#    on the negated distance map (watershed floods from minima), which
#    splits touching cells into separate masks
markers = skimage.measure.label(local_max_mask)
segmented_cells = skimage.segmentation.watershed(-distance, markers, mask=cells)
segmented_cells_rgb = skimage.color.label2rgb(segmented_cells, bg_label=0)

# Show every intermediate step of the pipeline side by side
fig, axs = plt.subplots(3, 2, figsize=(10, 14), tight_layout=True)
axs[0, 0].imshow(img, cmap='inferno')
axs[0, 0].set_title('Nuclear DNA')
axs[0, 1].imshow(cells, cmap='gray')
axs[0, 1].set_title(f'Threshold of {threshold:0.1f}')
axs[2, 0].imshow(labeled_cells_rgb)
axs[2, 0].set_title('Threshold\nIndividual Cells')
axs[1, 0].imshow(distance, cmap='gray')
axs[1, 0].set_title('Euclidian Distance\nTransform (EDT)')
axs[2, 1].imshow(segmented_cells_rgb)
axs[2, 1].set_title('Watershed\nIndividual Cells')
axs[1, 1].imshow(distance, cmap='gray')
axs[1, 1].scatter(local_max_coords[:, 1], local_max_coords[:, 0], 
                  s=10, marker='.', color='r')
axs[1, 1].set_title('Local Maximal Peaks')
axisoff(axs)
print()

In [ ]:
# Some zoomed regions to compare the two segmentations at the cell level
crop_height, crop_width = 50, 50  # size (in pixels) of each zoomed window
# Top-left (row, col) corners of the regions to inspect
crop_corners = [
    [45, 110],
    [120, 140],
    [400, 60],
    [310, 190]
    ]
figsize = (12, 3 * len(crop_corners))
fig, axs = plt.subplots(len(crop_corners), 4, figsize=figsize, tight_layout=True)
# One row per region: raw image, threshold labels, watershed labels, EDT peaks
for i, (row, col) in enumerate(crop_corners):
    slices = (slice(row, row + crop_height), slice(col, col + crop_width))
    axs[i, 0].imshow(img[slices], cmap='inferno')
    axs[i, 1].imshow(labeled_cells_rgb[slices])
    axs[i, 2].imshow(segmented_cells_rgb[slices])
    # Show the full distance map but restrict the view to the crop, so the
    # scatter coordinates (given in full-image pixels) still line up
    axs[i, 3].imshow(distance, cmap='gray')
    axs[i, 3].scatter(local_max_coords[:, 1], local_max_coords[:, 0], 
                  s=30, marker='.', color='r')
    axs[i, 3].set_xlim([slices[1].start, slices[1].stop])
    # y-limits inverted (stop before start) to match image orientation
    axs[i, 3].set_ylim([slices[0].stop, slices[0].start])
axisoff(axs)
print()

Cell Morphological Analysis¶

Now that we have identified our cells in the image, we can compute a bunch of characteristics for them and plot them!

In [ ]:
# Section to manipulate the cells individually
props_to_plot = ['area', 'max_intensity', 'eccentricity']  # properties to histogram
# Per-cell RegionProperties objects (kept for interactive exploration)
cells_props = skimage.measure.regionprops(segmented_cells, intensity_image=img)
# Same properties as a column-oriented table, convenient for plotting
cells_props_table = skimage.measure.regionprops_table(
    segmented_cells, 
    intensity_image=img,
    properties=props_to_plot)

# NOTE: removed a duplicate re-assignment of `props_to_plot` (same value).
# One histogram per property, normalized to a density
figsize = (10, 4 * len(props_to_plot))
fig, axs = plt.subplots(len(props_to_plot), 1, figsize=figsize, tight_layout=True)
for ax, prop_to_plot in zip(axs, props_to_plot):
    ax.hist(cells_props_table[prop_to_plot], bins=100, color='black', density=True)
    ax.set_title(f'Cells {prop_to_plot}')
    ax.set_ylabel('Distribution')

The convolution¶

Convolution is the mathematical operation of convolutional neural networks (CNNs), the most widely used type of network in deep learning applied to images. A convolution is defined by a filter that is sequentially applied to subsections of the image of the same size as the filter. The filter is dragged over the entire image and at each step, the convolution result is the sum of the element-by-element multiplication of the filter and the image section.

The following section shows step by step how to apply a convolution filter to an image. You can change the parameters to see the effect on the result!

In [ ]:
matplotlib.rcParams.update({'font.size': 18})

# Load the color cat test image and keep a 100x100 grayscale crop
# (channel mean) so the step-by-step convolution stays fast
im_color = skimage.data.cat()
im = np.mean(im_color[60:160,120:220],axis=2)

###--Parameters--###
kernel_size = 3 # must be odd here
stride = 1 # step (in pixels) between successive filter positions
padding = 1 # zero-padding added around the image border
###--------------###

# Definition of a Sobel filter of size kernel x kernel
def custom_sobel(shape, axis):
    """
    shape must be odd: eg. (5,5)
    axis is the direction, with 0 to positive x and 1 to positive y
    """
    k = np.zeros(shape)
    p = [(j,i) for j in range(shape[0]) 
           for i in range(shape[1]) 
           if not (i == (shape[1] -1)/2. and j == (shape[0] -1)/2.)]

    for j, i in p:
        j_ = int(j - (shape[0] -1)/2.)
        i_ = int(i - (shape[1] -1)/2.)
        k[j,i] = (i_ if axis==0 else j_)/float(i_*i_ + j_*j_)
    return k

# Sobel filter used as the example convolution kernel
filter = custom_sobel((kernel_size,kernel_size),0)

# Output (feature-map) size from the standard convolution size formula
feature_map_dim0 = int((im.shape[0]+2*padding-(kernel_size-1)-1)/stride + 1)
feature_map_dim1 = int((im.shape[1]+2*padding-(kernel_size-1)-1)/stride + 1)

# Result that builds iteratively (stored with the zero padding included)
conv_result = np.zeros((feature_map_dim0+2*padding,feature_map_dim1+2*padding))

fig, ax = plt.subplots(2,2,figsize=(10,10))
ax[0,0].imshow(filter,cmap='gist_gray')
for xx in range(filter.shape[0]):
  for yy in range(filter.shape[1]):
    # Black text over bright filter pixels, white over dark ones
    text_color = 'k' if filter[xx,yy] > 0 else 'w'
    ax[0,0].text(yy,xx,s=filter[xx,yy],color=text_color,horizontalalignment='center',verticalalignment='center')
ax[0,0].set_title('Filtre')
ax[1,1].set_title('Résultat de la convolution')

for axi in ax.ravel():
  axi.axis('off')

# Animate the sliding filter for at most `time_limit` seconds
time_limit = 30
time_before = time.time()
for i in range(feature_map_dim0-int(2*padding/stride)):
  for j in range(feature_map_dim1-int(2*padding/stride)):
    im_section = im[stride*i:stride*i+kernel_size, stride*j:stride*j+kernel_size]
    # BUGFIX: a convolution step is the sum of the ELEMENT-WISE product of the
    # filter and the image section (as the text above describes); the original
    # code used np.matmul (a matrix product), which gives a different result.
    conv_result[i+padding,j+padding] = np.sum(filter * im_section)
    # update display only every N steps, otherwise it is too slow
    if time.time()-time_before<time_limit:
      if j % 10 == 0:
        ax[1,0].cla()
        ax[1,0].imshow(im, cmap='gist_gray', vmin=0, vmax=255)
        ax[1,0].axis('off')
        conv = matplotlib.patches.Rectangle((-0.5+stride*j,-0.5+stride*i), width=kernel_size, height=kernel_size, linewidth=1, edgecolor='r', facecolor='none')
        ax[1,0].add_patch(conv)
        ax[1,0].set_title('Image')

        ax[0,1].cla()
        ax[0,1].imshow(im_section, cmap='gist_gray',vmin=0,vmax=255)
        ax[0,1].set_title('Section de l\'image')
        ax[0,1].axis('off')
        for xx in range(filter.shape[0]):
          for yy in range(filter.shape[1]):
            # BUGFIX: the original reused the stale `color` variable left over
            # from the filter-display loop above; use a fixed readable color.
            ax[0,1].text(yy,xx,s=im_section[xx,yy].astype(int),color='r',horizontalalignment='center',verticalalignment='center')

        ax[1,1].imshow(conv_result, cmap='gist_gray')
        display.display(plt.gcf())
        display.clear_output(wait=True)
        time.sleep(0.1)

ax[1,1].imshow(conv_result, cmap='gist_gray')
Out[ ]:
<matplotlib.image.AxesImage at 0x7fac71e01190>

Convolution parameters¶

The convolution is defined by 4 main parameters:

  • Kernel size: The size, in pixels, of one side of the filter. Although it is possible to apply a convolution with a non-square filter, it is generally the geometry that is used in the CNN.
  • Stride: The displacement in pixels between each operation
  • Padding: As long as the filter has a size greater than 1x1, the result of the convolution will have a size smaller than the original image. A padding on the border of the image is often added so that the result has the same size as the original image.
  • Dilation: The spacing between pixels in the section of the image to which the filter is applied.

These parameters are presented visually in the next section.

In [ ]:
crop_size = 15
im = np.mean(skimage.data.astronaut(),axis=2)[:crop_size,:crop_size]

#---Kernel size---#
kernel_size_list = [1,3,5,7]
#-----------------#

fig, axs = plt.subplots(4,4,figsize=(20,20))

for ax, kernel_size in zip(axs[0].ravel(), kernel_size_list):
  ax.imshow(im,cmap='gist_gray')
  filter = matplotlib.patches.Rectangle((-0.5,-0.5), width=kernel_size, height=kernel_size, linewidth=1, edgecolor='r', facecolor='none')
  ax.add_patch(filter)
  ax.axis('off')
  ax.set_ylim(14.5,-0.5)
  ax.set_title('Kernel size = {}x{}'.format(kernel_size,kernel_size))

#---Stride---#
stride_list = [1,2,3,4]
#------------#

for ax, stride in zip(axs[1].ravel(), stride_list):
  color_idx = 0
  color_range = len(range(0,im.shape[0]-2,stride))*len(range(0,im.shape[1]-2,stride))
  viridis = matplotlib.cm.get_cmap('viridis', color_range)
  ax.imshow(im, cmap='gist_gray')
  for i in range(0,im.shape[0]-2,stride):
    for j in range(0,im.shape[1]-2,stride):
      filter = matplotlib.patches.Rectangle((j-0.5,i-0.5), width=3, height=3, linewidth=1, edgecolor=viridis(color_idx), facecolor='none')
      color_idx += 1/color_range
      ax.add_patch(filter)
  ax.axis('off')
  ax.set_ylim(15.5,-0.5)
  ax.set_title('Stride = {}'.format(stride))

#---Padding---#
padding_list = [0,1,2,3]
#-------------#

filter = custom_sobel((7,7), 0)
conv_result = scipy.signal.convolve2d(im, filter, mode='valid')

for ax, padding in zip(axs[2].ravel(), padding_list):
  conv_result_padded = np.pad(conv_result, padding)
  ax.imshow(conv_result_padded, cmap='viridis')
  ax.set_ylim(15.5,-0.5)
  ax.set_xlim(15.5,-0.5)
  ax.set_title('Padding = {} px'.format(padding))
  ax.axis('off')

#---Dilation---#
dilation_list = [1,2,3,4]
#--------------#

for ax, dilation in zip(axs[3].ravel(), dilation_list):
  ax.imshow(im, cmap='gist_gray')
  for i in range(3):
    for j in range(3):
      px = matplotlib.patches.Rectangle((j-0.5+j*(dilation-1),i-0.5+i*(dilation-1)), width=1, height=1, linewidth=1, edgecolor='r', facecolor='none')
      ax.add_patch(px)
  ax.set_title('Dilation = {}'.format(dilation))

plt.show()

Why do we use convolution?¶

Convolution filters are used to highlight image features that span multiple pixels. The result of a convolution applied to an image is called a feature map. Feature maps will indicate the locations in the image that contain the element of interest defined by the filter: a circle, a straight vertical line, a point, etc. By applying several filters one after the other, the highlighted elements can become more complex: an eye, a cat's ear, a car wheel, etc.

The following code section shows the result of a simple filter, the Sobel filter, which allows you to highlight the horizontal, vertical and diagonal borders of an image.

In [ ]:
fig, axs = plt.subplots(2,4,figsize=(20,10))
im = np.mean(skimage.data.astronaut(),axis=2)

sobel_horizontal = np.array([[1,2,1], [0,0,0], [-1,-2,-1]])
sobel_vertical = np.array([[-1,0,1], [-2,0,2], [-1,0,1]])
sobel_diagonal1 = np.array([[0,1,2], [-1,0,1], [-2,-1,0]])
sobel_diagonal2 = np.array([[-2,-1,0], [-1,0,1], [0,1,2]])

filter_list = [sobel_horizontal, sobel_vertical, sobel_diagonal1, sobel_diagonal2]
filter_name_list = ['Horizontal', 'Vertical', 'Diagonal -45', 'Diagonal 45']

for idx, filter in enumerate(filter_list):
  axs[0,idx].imshow(filter, cmap='gist_gray')
  axs[1,idx].imshow(scipy.signal.convolve2d(im,filter), cmap='gist_gray')
  axs[0,idx].set_title(filter_name_list[idx])
  axs[0,idx].axis('off')
  axs[1,idx].axis('off')

plt.show()

Deep learning applied to images¶

Convolution neural network (CNN)¶

The convolution neural network sequentially applies convolutions to an image. The number of operations per layer, the number of layers, and the parameters of the convolution are fixed, but the weights and biases of the convolutions are learned in training.

The following code section illustrates the learning of these weights and biases by comparing randomly initialized AlexNet convolution filters and these same filters after training.

In [ ]:
def plot_weights(model, layer_num, title=""):
  """Display (up to 64) filters of a convolutional layer as RGB tiles."""
  # Fetch the requested layer from the feature extractor
  layer = model.features[layer_num]

  # Guard clause: only convolutional layers have displayable filters
  if not isinstance(layer, nn.Conv2d):
    print("Can only visualize layers which are convolutional")
    return

  weights = layer.weight.data
  # Min-max normalize to [0, 1] so the filters render as images
  w_min, w_max = weights.min(), weights.max()
  weights = (weights - w_min) / (w_max - w_min) * 1.0
  # torch stores filters as CxHxW; matplotlib expects HxWxC
  weights = weights.permute(0, 2, 3, 1)

  fig, axs = plt.subplots(4, 16, figsize=(20, 5))  # we stop at 64 filters
  flat_axs = axs.ravel()
  for idx in range(min(weights.shape[0], 64)):
    flat_axs[idx].imshow(weights[idx])
    flat_axs[idx].text(0, -1, s=str(idx))
    flat_axs[idx].axis('off')
  plt.suptitle(title)
  plt.show()
        
# Visualize the first-convolution weights of AlexNet, untrained vs trained
alexnet = models.alexnet(pretrained=False)  # randomly initialized weights
plot_weights(alexnet, 0, title="Random filters")
alexnet_trained = models.alexnet(pretrained=True)  # ImageNet weights (downloads a checkpoint)
plot_weights(alexnet_trained, 0, title="Learned filters")
Downloading: "https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-4df8aa71.pth

Now, let's visualize the result of these filters applied to an image! Notice that a random filter does not highlight any specific element, whereas the learned filters bring out basic elements such as lines, circles, grid patterns, etc.

Change the "filter_idx" variable to apply the filter of your choice, referring to the filter numbers of the previous cell.

In [ ]:
#---Choose the filter to apply---#
filter_idx = 2
#--------------------------------#

# Image as a 1x3xHxW float batch (torch channel-first convention)
im = torch.tensor(skimage.data.astronaut()).unsqueeze(0)
im = im.permute(0,3,1,2).float()

# Apply the first convolution layer of AlexNet (random and trained) to the image
im_conv_nottrained = alexnet.features[:1](im)
im_conv_trained = alexnet_trained.features[:1](im)

fig, axs = plt.subplots(2,2,figsize=(10,10))

# One row per network: the filter itself (left) and its feature map (right)
rows = (
    (0, alexnet, im_conv_nottrained, 'Random filter'),
    (1, alexnet_trained, im_conv_trained, 'Trained filter'),
)
for row, net, conv_out, label in rows:
    weights = net.features[0].weight[filter_idx].permute(1,2,0).detach().numpy()
    # Min-max normalization so the filter displays as an RGB image
    weights = (weights - weights.min()) / (weights.max() - weights.min())
    axs[row,0].imshow(weights)
    axs[row,0].set_title(label)
    axs[row,1].imshow(conv_out[0,filter_idx,...].detach().numpy(), cmap='gist_gray')
    axs[row,1].set_title('Feature map')

for ax in axs.ravel():
    ax.axis('off')
In [ ]:
# Display all feature maps from first convolution layer of AlexNet
fig, axs = plt.subplots(8,8,figsize=(10,10))
for channel, ax in enumerate(axs.ravel()):
  ax.imshow(im_conv_trained[0,channel,...].detach().numpy(),cmap='gist_gray')
  ax.axis('off')

# Highlight two hand-picked maps that respond to recognizable features
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
for ax, channel in zip(axs, (2, 21)):
  ax.imshow(im_conv_trained[0, channel].detach(), cmap='gist_gray')
  ax.axis('off')
axs[0].annotate('Edge?', (50, 87), (70, 60),
                color='w',
                arrowprops={'arrowstyle':'->', 'linewidth': 2, 'color':'w'},
                bbox=dict(facecolor='black', alpha=0.5, edgecolor='w', boxstyle='round'))
axs[1].annotate('Lip?', (55, 38), (50, 85),
                color='w',
                arrowprops={'arrowstyle':'->', 'linewidth': 2, 'color':'w'},
                bbox=dict(facecolor='black', alpha=0.5, edgecolor='w', boxstyle='round'))
Out[ ]:
Text(50, 85, 'Lèvre?')

Automatic Feature Extraction¶

The values of the filters of the convolution network are learned by comparing the output of the network to the ground truth. The ground truth is the set of labels created manually by an expert prior to the network training (supervised training).

In the following code cell, we can visualize the filters learned by a network. By increasing the cnn_layer variable to visualize the deeper layers of the network, we can clearly observe that the features highlighted by the filters become more complex by the application of several filters in cascade. Conversely, for the first layers of the network, the features are much simpler.

Here, the network is already trained, so we do not optimize its weights. Instead, we optimize the input image to maximize the activation of the chosen filter. For more information, I absolutely recommend to consult the source of the method: How to visualize convolutional features in 40 lines of code.

In [ ]:
# Source: https://github.com/utkuozbulak/pytorch-cnn-visualizations
# Source: https://towardsdatascience.com/how-to-visualize-convolutional-features-in-40-lines-of-code-70b7d87b0030

def save_image(im, path):
    """
        Saves a numpy matrix or PIL image as an image
    Args:
        im_as_arr (Numpy array): Matrix of shape DxWxH
        path (str): Path to the image
    """
    if isinstance(im, (np.ndarray, np.generic)):
        # numpy input: coerce to a saveable HxWx3 uint8 array first
        im = Image.fromarray(format_np_output(im))
    im.save(path)


def preprocess_image(pil_im, resize_im=True):
    """
        Processes image for CNNs
    Args:
        pil_im (PIL_img): PIL Image or numpy array to process
        resize_im (bool): Resize to 224 or not
    returns:
        im_as_var (torch variable): Variable that contains processed float tensor
    Raises:
        TypeError: if the input cannot be converted to a PIL Image
    """
    # mean and std list for channels (Imagenet)
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # Ensure or transform incoming image to PIL image
    if not isinstance(pil_im, Image.Image):
        try:
            pil_im = Image.fromarray(pil_im)
        except Exception as e:
            # BUGFIX: previously this only printed and fell through, which made
            # the function crash later with a confusing error; fail fast instead.
            raise TypeError(
                "could not transform PIL_img to a PIL Image object. Please check input."
            ) from e

    # Resize image
    if resize_im:
        # BUGFIX: Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
        # long-standing alias for the same resampling filter.
        pil_im = pil_im.resize((224, 224), Image.LANCZOS)

    im_as_arr = np.float32(pil_im)
    im_as_arr = im_as_arr.transpose(2, 0, 1)  # Convert array to D,W,H
    # Normalize each channel with the ImageNet statistics
    for channel, _ in enumerate(im_as_arr):
        im_as_arr[channel] /= 255
        im_as_arr[channel] -= mean[channel]
        im_as_arr[channel] /= std[channel]
    # Convert to float tensor
    im_as_ten = torch.from_numpy(im_as_arr).float()
    # Add one more channel to the beginning. Tensor shape = 1,3,224,224
    im_as_ten.unsqueeze_(0)
    # Convert to Pytorch variable (deprecated API, kept for compatibility with
    # the rest of this notebook; equivalent to requires_grad_() on a tensor)
    im_as_var = Variable(im_as_ten, requires_grad=True)
    return im_as_var


def recreate_image(im_as_var):
    """
        Recreates images from a torch variable, sort of reverse preprocessing
    Args:
        im_as_var (torch variable): Image to recreate
    returns:
        recreated_im (numpy arr): Recreated image in array
    """
    # Inverse of the ImageNet normalization applied in preprocess_image
    reverse_mean = [-0.485, -0.456, -0.406]
    reverse_std = [1/0.229, 1/0.224, 1/0.225]
    recreated = copy.copy(im_as_var.data.numpy()[0])
    for channel in range(3):
        recreated[channel] /= reverse_std[channel]
        recreated[channel] -= reverse_mean[channel]
    # Clip to [0, 1] then quantize to 8-bit
    np.clip(recreated, 0, 1, out=recreated)
    recreated = np.round(recreated * 255)
    # CxHxW -> HxWxC for display/saving
    return np.uint8(recreated).transpose(1, 2, 0)

def format_np_output(np_arr):
    """
        Normalize an array of shape WxH, 1xWxH or 3xWxH into a WxHx3 uint8
        image that PIL can save, via successive if clauses.
    Args:
        np_arr (Numpy array): Matrix of shape 1xWxH or WxH or 3xWxH
    """
    arr = np_arr
    # Case 1: bare WxH matrix -> give it a leading channel axis
    if arr.ndim == 2:
        arr = arr[np.newaxis, ...]
    # Case 2: single channel -> tile to 3 channels (grayscale -> RGB)
    if arr.shape[0] == 1:
        arr = np.repeat(arr, 3, axis=0)
    # Case 3: 3xWxH -> WxHx3 so PIL can save it
    if arr.shape[0] == 3:
        arr = arr.transpose(1, 2, 0)
    # Case 4: values normalized to [0, 1] -> rescale to 8-bit
    if np.max(arr) <= 1:
        arr = (arr * 255).astype(np.uint8)
    return arr


class CNNLayerVisualization():
    """
        Produces an image that minimizes the loss of a convolution
        operation for a specific layer and filter
    """
    def __init__(self, model, selected_layer, selected_filter):
        # model is indexed with [] below, so it is expected to be a Sequential
        # of feature layers (e.g. vgg16().features) -- not a full classifier
        self.model = model
        self.model.eval()
        self.selected_layer = selected_layer
        self.selected_filter = selected_filter
        # Overwritten by the forward hook at every forward pass
        self.conv_output = 0

    def hook_layer(self):
        # Register a forward hook that captures the activation map of the
        # selected filter each time the selected layer runs.
        def hook_function(module, grad_in, grad_out):
            # Gets the conv output of the selected filter (from selected layer)
            self.conv_output = grad_out[0, self.selected_filter]
        # Hook the selected layer
        self.model[self.selected_layer].register_forward_hook(hook_function)

    def visualise_layer_with_hooks(self):
        # Gradient ascent on the INPUT image: maximize the mean activation of
        # the selected filter, upscaling the image by 1.2x after each of the 5
        # outer steps (a small pyramid that improves the final resolution).
        # Hook the selected layer
        self.hook_layer()
        # Generate a random image
        sz = 64
        random_image = np.uint8(np.random.uniform(150, 180, (sz, sz, 3)))
        # Process image and return variable
        processed_image = preprocess_image(random_image, False)
        for j in range(5): # Increase the number of iterations to get a better resolution (but requires more computing time)
          # Define optimizer for the image
          # (re-created at each scale because the resize below builds a new leaf tensor)
          optimizer = Adam([processed_image], lr=0.1, weight_decay=1e-6)
          for i in range(1, 11):
              optimizer.zero_grad()
              # Assign create image to a variable to move forward in the model
              x = processed_image
              for index, layer in enumerate(self.model):
                  # Forward pass layer by layer
                  # x is not used after this point because it is only needed to trigger
                  # the forward hook function
                  x = layer(x)
                  # Only need to forward until the selected layer is reached
                  if index == self.selected_layer:
                      # (forward hook function triggered)
                      break
              # Loss function is the mean of the output of the selected layer/filter
              # We try to minimize the mean of the output of that specific filter
              loss = -torch.mean(self.conv_output)
              # Backward
              loss.backward()
              # Update image
              optimizer.step()
          print('Iteration:', str(j), 'Loss:', "{0:.2f}".format(loss.data.numpy()))
          # Rescale image by factor 1.2
          sz = int(1.2 * sz)
          # Round-trip through numpy/OpenCV for the resize, then re-wrap as a
          # gradient-requiring tensor (this detaches it from the old graph)
          processed_image = processed_image.detach().numpy()
          processed_image = processed_image[0].transpose(1,2,0)
          processed_image = cv2.resize(processed_image, (sz, sz), interpolation = cv2.INTER_CUBIC)
          processed_image = torch.tensor(processed_image).permute(2,0,1).unsqueeze(0)
          processed_image.requires_grad = True
        # Recreate image
        self.created_image = recreate_image(processed_image)
        fig, ax = plt.subplots(1,1,figsize=(10,10))
        ax.imshow(self.created_image)
        ax.axis('off')

#---Parameters---#
# Filter complexity increases with the depth of the convolution (cnn_layer)
cnn_layer = 24 # choose between 0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24
filter_pos = 1 # [0,63] if cnn_layer = (0,2), [0,127] if cnn_layer = (5,7), [0,255] if cnn_layer = (10,12,14), [0,511] otherwise
#----------------#

# Fully connected layer is not needed: keep only the convolutional features
pretrained_model = models.vgg16(pretrained=True).features
layer_vis = CNNLayerVisualization(pretrained_model, cnn_layer, filter_pos)

# Layer visualization with pytorch hooks
layer_vis.visualise_layer_with_hooks()
Iteration: 0 Loss: -36.66
Iteration: 1 Loss: -58.07
Iteration: 2 Loss: -78.52
Iteration: 3 Loss: -88.35
Iteration: 4 Loss: -88.06

For more examples and an explanation of how the method works: How to visualize convolutional features in 40 lines of code

Gradient activation view¶

When a network has been trained, a naive way to see what the network is looking at when making a prediction is to examine the gradient values for the predicted class.

In [ ]:
matplotlib.rcParams.update({'font.size': 20})

# Fixed seed so the random operations below are reproducible
torch.manual_seed(42)

# Download ImageNet labels
!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

# Read the categories (one class name per line)
with open("imagenet_classes.txt", "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Load a dog image from pytorch site

url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
urllib.request.urlretrieve(url, filename)

# Dog image in the range 0 - 255
dog = Image.open(filename)

# Define image transformations
alexnet_stats = {
    'mean': [0.485, 0.456, 0.406], 
    'std': [0.229, 0.224, 0.225]
} # see https://pytorch.org/hub/pytorch_vision_alexnet/
reshape = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
])
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(**alexnet_stats),
])

# Reshape the dog image to 224x224
dog = reshape(dog)
# Keep an un-normalized copy for display
reshaped_dog = dog.copy()
# Normalize the intensity value of the dog image
dog = normalize(dog)
dog = dog.unsqueeze(0) # Create a mini-batch of 1 element

# Set the dog image to require gradients, so we can backpropagate to the input
dog.requires_grad = True

# Apply AlexNet to get a prediction on the image
output = alexnet_trained(dog)[0]

# output is a Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
# The output has unnormalized scores (not between 0 and 1). 
# To get probabilities, we will run a softmax on it.
probabilities = torch.nn.functional.softmax(output, dim=0)

# We will get the index with the highest probability
highest_probability_indice = torch.argmax(output)

# From that index, we can retrieve the predicted category
predicted_category = categories[highest_probability_indice]

# We will also retrieve the probability the network gave to this category
predicted_probability = probabilities[highest_probability_indice]

# Since we set the input image to require gradients, we can backpropagate
# the gradients associated to the highest category to get a glimpse of which
# pixels are influencing the prediction the most
output[highest_probability_indice].backward()
gradients = dog.grad.squeeze(0).numpy()
# Max over the RGB channels to get one saliency value per pixel
gradients = np.max(gradients, axis=0)

# Plotting
fig, axs = plt.subplots(1, 3, figsize=((18, 6)))
axs[0].imshow(reshaped_dog)
axs[0].set_title(f'{predicted_category} at {predicted_probability*100:.1f}%', fontsize=20)
axs[1].imshow(gradients, cmap='viridis')
axs[1].set_title('Raw Gradients', fontsize=20)
axs[2].imshow(reshaped_dog)

# Overlay a smoothed density map of the positive gradients only
x = np.arange(0, 224)
X, Y = np.meshgrid(x, x)
gradients[gradients < 0] = 0
axs[2].contourf(X, Y, gradients, alpha=0.5, levels=5)
axs[2].set_title('Gradients > 0', fontsize=20)

for ax in axs:
    ax.axis('off')
--2021-04-30 13:12:59--  https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10472 (10K) [text/plain]
Saving to: ‘imagenet_classes.txt’

imagenet_classes.tx 100%[===================>]  10.23K  --.-KB/s    in 0s      

2021-04-30 13:12:59 (100 MB/s) - ‘imagenet_classes.txt’ saved [10472/10472]

In [ ]:
# Create a simple convolutional network without a final classification layer
def simple_network(classification_layer=False):
    """Build a small 3-conv CNN; optionally append a 10-way linear head."""
    layers = [
        nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1),
        nn.AvgPool2d(2),
        nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
        nn.AvgPool2d(2),
        nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1),
        nn.AvgPool2d(2),
        nn.Flatten(),
    ]
    if classification_layer:
        # The linear layer assumes a 224x224 input: 224 / 2**3 = 28
        layers.append(nn.Linear(32 * 28 * 28, 10))
    return nn.Sequential(*layers)

# Compare random receptive-field gradients with and without a fully connected layer
fig, axs = plt.subplots(1, 3, figsize=((12, 6)), tight_layout=True)
for ax_id, classification_layer, n_samples in zip([1, 2], [False, True], [25, 1]):

    # Taking the dog example again, we set gradients to zero
    dog.grad.zero_()

    # Forward pass in our simple network
    net = simple_network(classification_layer=classification_layer)
    output = net(dog)[0]

    # Backpropagate from a few randomly chosen output activations
    for indice in np.random.randint(0, output.size(), n_samples):
        output[indice].backward(retain_graph=True)

    # Max out gradients over RGB channels
    gradients = dog.grad.squeeze(0).numpy()
    gradients = np.max(gradients, axis=0)

    # Plotting

    axs[ax_id].imshow(reshaped_dog)

    # Get a smoothed density map
    x = np.arange(0, 224)
    X, Y = np.meshgrid(x, x)
    axs[ax_id].contourf(X, Y, gradients, alpha=0.5, levels=5)
    with_without = 'With' if classification_layer else 'Without'
    axs[ax_id].set_title(f'Random Receptive Fields\n{with_without} Fully Connected\nLayer', fontsize=20)
    axs[ax_id].axis('off')

axs[0].imshow(reshaped_dog)
axs[0].set_title('Input image', fontsize=20)
axs[0].axis('off')
Out[ ]:
(-0.5, 223.5, 223.5, -0.5)

Possible Tasks¶

Multiple tasks can be addressed with convolutional neural networks in vision.

Classification¶

As seen in the presentation from Frederic Paradis, the classification task of images can now be considered solved !

image.png

Segmentation¶

In a segmentation task, we are interested in classifying all the pixels of an input image. This is called a dense segmentation. In this case, the possible classes for each pixel are also defined.

To do this, just like for the classification task, the first section is used for feature extraction. These allow the network to project the input image into an abstract feature space. From this space, operations are performed to return to the original space, to obtain a spatial representation of what the network has detected in the original image.

This is an example of such network.

In [ ]:
# The PyTorch U-Net example is broken upstream
# (see https://github.com/mateuszbuda/brain-segmentation-pytorch/issues/20#issue-752962124),
# so we simply download a brain slice and its tumour mask.
url = 'https://drive.google.com/uc?export=view&id=1TZxLk8XrwN8CDMQ7yj2PilEkwjvf7SLs'
url_mask = 'https://drive.google.com/uc?export=view&id=1O9HnzE3yLgCHOH7909ZQioKrW3yVNamG'
filename = 'brain.tif'
maskname = 'mask.tif'
for link, target in ((url, filename), (url_mask, maskname)):
    urllib.request.urlretrieve(link, target)

brain_image = Image.open(filename)
brain_mask = np.array(Image.open(maskname)) > 0
# One-pixel outline of the mask: dilation XOR mask
borders = scipy.ndimage.binary_dilation(brain_mask) ^ brain_mask

fig, axs = plt.subplots(1, 3, figsize=((18, 6)))
axs[0].imshow(brain_image)
axs[0].set_title('Brain Slice', fontsize=20)
axs[0].axis('off')
axs[1].imshow(brain_mask, cmap='gray')
axs[1].set_title('Segmentation of Low-Grade\nGlioma (Tumour)', fontsize=20)
axs[1].axis('off')
# Overlay the white outline on the original slice
brain_image = np.array(brain_image).copy()
brain_image[borders] = 255
axs[2].imshow(brain_image)
axs[2].set_title('Segmentation of Low-Grade\nGlioma (Tumour)', fontsize=20)
axisoff(axs)
print()

Image generation and Image-to-Image translation¶

Translating images from one domain to another has its share of applications, most of them playful:

  • Image colorization (to try it: Image Colorization API)

  • Generation of maps from satellite images: Pix2Pix

  • Turning horses into zebras: Cycle-GAN

  • Super-resolution:

Image generation has similar applications, but the images are created randomly, from scratch, rather than translated from a real image:

  • Generation of human faces (ThisPersonDoesNotExist)):

  • Generation of cat faces (ThisCatDoesNotExist):

  • ... and so much more: (ThisXDoesNotExist)

All the presented applications were produced using adversarial generative networks (GAN). This architecture is composed of two simultaneously trained networks: a generator and a discriminator.

  • Generator: fully convolutional network that builds an image from a random value (GAN) or from an image (conditional GAN).

  • Discriminator: convolutional classification network, whose task is to differentiate real images from generated images.


source

Practical points of reference for starting a deep learning project¶


Check if deep learning is appropriate for the task¶

  • First try simpler and faster machine learning methods to test:
    • Support Vector Machines (SVM)

    • Random forest


    • Principal Component Analysis (PCA)

    • Threshold-based segmentation


    • Clustering
    • Other classical methods.

Define the task¶

  • Segmentation? Classification? Detection? Generation?
  • How many classes? What type of segmentation, binary or semantic?
  • How will performance be evaluated? Should false positive be as impactful as false negative?

Prepare the data¶

  • How many images are available? What annotations are available?
  • Are the classes unbalanced? Should they be?
  • What data augmentation methods are applicable? For example, numbers are not rotation invariant (a 6 upside down should be classified as a 9, not a 6).




Choose the network architecture¶

  • Pre-trained models are available in the torchvision library: https://pytorch.org/vision/stable/models.html
  • Consult the literature for the best model of the moment for a particular task:

    • Classification, many images: ResNeXt-101-32x8d
    • Classification, fewer images: VGG-11, AlexNet
    • Segmentation: U-Net
    • Instance detection: Mask-RCNN
  • Warning: deeper networks are generally more efficient, but training them requires much more data and computational resources.

Hyperparameters Selection¶

  • Look at published papers that use a similar method for a similar task and examine the hyperparameters used.
  • If training is resource and time efficient, replicate several trainings with varying hyperparameters (grid search).
  • Examine the evolution of the validation and training losses; if the losses oscillate a lot between epochs, you may consider reducing the learning rate or increasing the batch size. If the loss converges but the actual performance does not, consider using a different objective function better suited to your specific task.

Resources¶

Image analysis¶

  • Python libraries: 10 Python image manipulation tools
    • scikit-image (Free collection of algorithms for image processing)
    • Pillow (friendly fork of the Python Imaging Library)
    • numpy (Remember: images are matrices!)
  • Pytorch tutorials
  • Plugins for ImageJ : https://imagej.net/Category:Plugins

Deep learning¶

  • Pytorch
  • For a theoretical understanding : Deep Learning Book
  • In-depth publications about data science, often with code snippets to help understand the concepts explained : Towards Data Science
  • Youtube channel 3Blue1Brown

Computing resources¶

  • Google Colab (Zero configuration required, Free access to GPUs, Easy sharing)